
[BugFix] Directly Convert Modifiers to Recipe Instance#1271

Merged
dsikka merged 5 commits into main from bugfix-modifier-parsing on Apr 2, 2025

Conversation

@rahul-tuli
Collaborator

Problem

Currently, the process of recipe creation follows this sequence:

Modifiers → String (Serialization) → Recipe Instance (Deserialization)

This intermediate serialization and deserialization step introduces issues when dealing with more complex objects, such as SmoothQuant mappings, which can lead to parsing errors.
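The failure mode is not specific to llm-compressor's serializer. As a generic illustration (using JSON here rather than the library's actual format), round-tripping live Python objects through a text representation can silently change their types, so structured values like a SmoothQuant mapping tuple no longer match what the parser expects:

```python
import json

# Hypothetical stand-in for a SmoothQuant mapping: a (targets, anchor) pair.
mapping = (["re:.*query_key_value"], "re:.*input_layernorm")

# Serialize to text and parse it back, as the old flow effectively did.
restored = json.loads(json.dumps(mapping))

print(type(mapping).__name__)   # tuple
print(type(restored).__name__)  # list
```

The round trip preserves the data but not the structure: the tuple comes back as a list, and code that validates or pattern-matches on the original type can fail.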

Solution

This PR refactors the flow to directly construct the Recipe Instance from Modifiers, thereby removing an unnecessary conversion step and eliminating a potential source of error.
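In outline, the change replaces the string round trip with direct construction from the modifier objects themselves. The sketch below uses simplified stand-in classes to show the shape of the idea; the real Recipe and Modifier types in llm-compressor carry much more state, and the names here are illustrative:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Modifier:
    """Minimal stand-in for an llm-compressor modifier."""
    name: str


@dataclass
class Recipe:
    """Minimal stand-in for a recipe built directly from live modifiers."""
    modifiers: List[Modifier]

    @classmethod
    def from_modifiers(cls, modifiers: List[Modifier]) -> "Recipe":
        # Build the recipe from the objects directly; no serialize/parse step,
        # so complex fields (e.g. SmoothQuant mappings) survive unchanged.
        return cls(modifiers=list(modifiers))


recipe = Recipe.from_modifiers([Modifier("SmoothQuant"), Modifier("GPTQ")])
print(len(recipe.modifiers))  # 2
```

Because the modifiers are never flattened to text, whatever types they hold are carried into the recipe as-is.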

Issue Tracking

This issue was originally surfaced in vllm-project/llm-compressor#37 and is formally tracked under [INFERENG-358](https://issues.redhat.com/browse/INFERENG-358).

Testing

The issue was reproduced using the following script, which previously errored out but now runs successfully with this fix:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
MODEL_ID = "bigscience/bloom-3b"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Define quantization recipe
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        mappings=[
            (["re:.*query_key_value"], "re:.*input_layernorm"),
            (["re:.*dense_h_to_4h"], "re:.*post_attention_layernorm"),
        ],
    ),
    GPTQModifier(
        scheme="W8A8",
        targets="Linear",
        ignore=["lm_head"],
        dampening_frac=0.003,
    ),
]

# Load and preprocess dataset
dataset = load_dataset(DATASET_ID, split=DATASET_SPLIT)
dataset = dataset.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def preprocess(example):
    """Formats the messages into a simple dialogue format."""
    text = "\n".join([msg["content"] for msg in example["messages"]])
    return {"text": text}

dataset = dataset.map(preprocess)

# Apply quantization
oneshot(
    model=model,
    dataset=dataset,
    recipe=recipe,
    output_dir="bloom-3b-gptq-w8a8",
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

With this fix, the script runs to completion without errors. Automated tests covering the new behavior have also been added.

… Recipe

Add e2e tests for recipe parsing

Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@rahul-tuli added the ready label on Mar 20, 2025
Collaborator

@brian-dellabetta left a comment

looks good! couple suggestions based on the parts i understand

Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
@rahul-tuli self-assigned this on Mar 25, 2025
Collaborator

@brian-dellabetta left a comment

cool, thanks!

@rahul-tuli enabled auto-merge (squash) March 26, 2025 14:29
Collaborator

@kylesayrs left a comment

Nice

@dsikka disabled auto-merge April 2, 2025 16:48
@dsikka enabled auto-merge (squash) April 2, 2025 17:23
@dsikka merged commit 027caa4 into main Apr 2, 2025
8 checks passed
@dsikka deleted the bugfix-modifier-parsing branch April 2, 2025 17:42